{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MAXP 2021初赛数据探索和处理-4\n",
    "\n",
    "把原始数据的标签转换成数字形式，并完成Train/Validation/Test的分割。这里的划分用于比赛中的模型训练和模型选择，与原始数据文件名中的划分并不对应。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "import dgl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path\n",
    "base_path = '/Users/jamezhan/PycharmProjects/MAXP/final_dataset'\n",
    "publish_path = 'publish'\n",
    "\n",
    "nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取节点列表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3655452, 4)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>node_idx</th>\n",
       "      <th>paper_id</th>\n",
       "      <th>Label</th>\n",
       "      <th>Split_ID</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3655448</th>\n",
       "      <td>3655448</td>\n",
       "      <td>caed47d55d1e193ecb1fa97a415c13dd</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3655449</th>\n",
       "      <td>3655449</td>\n",
       "      <td>c82eb6be79a245392fb626b9a7e1f246</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3655450</th>\n",
       "      <td>3655450</td>\n",
       "      <td>926a31f6b378575204aae30b5dfa6dd3</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3655451</th>\n",
       "      <td>3655451</td>\n",
       "      <td>bbace2419c3f827158ea4602f3eb35fa</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         node_idx                          paper_id Label  Split_ID\n",
       "3655448   3655448  caed47d55d1e193ecb1fa97a415c13dd   NaN         1\n",
       "3655449   3655449  c82eb6be79a245392fb626b9a7e1f246   NaN         1\n",
       "3655450   3655450  926a31f6b378575204aae30b5dfa6dd3   NaN         1\n",
       "3655451   3655451  bbace2419c3f827158ea4602f3eb35fa   NaN         1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the node table, reading the Label column as str\n",
    "# (nodes without a label show up as NaN)\n",
    "nodes_df = pd.read_csv(\n",
    "    nodes_path,\n",
    "    dtype={'Label': str},\n",
    ")\n",
    "print(nodes_df.shape)\n",
    "nodes_df.tail(4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 转换标签为数字"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(23, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>node_idx</th>\n",
       "      <th>paper_id</th>\n",
       "      <th>Split_ID</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Label</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>A</th>\n",
       "      <td>2670</td>\n",
       "      <td>2670</td>\n",
       "      <td>2670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>B</th>\n",
       "      <td>65303</td>\n",
       "      <td>65303</td>\n",
       "      <td>65303</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>C</th>\n",
       "      <td>111502</td>\n",
       "      <td>111502</td>\n",
       "      <td>111502</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>D</th>\n",
       "      <td>104005</td>\n",
       "      <td>104005</td>\n",
       "      <td>104005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>E</th>\n",
       "      <td>45014</td>\n",
       "      <td>45014</td>\n",
       "      <td>45014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>F</th>\n",
       "      <td>32876</td>\n",
       "      <td>32876</td>\n",
       "      <td>32876</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>G</th>\n",
       "      <td>43452</td>\n",
       "      <td>43452</td>\n",
       "      <td>43452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>H</th>\n",
       "      <td>71824</td>\n",
       "      <td>71824</td>\n",
       "      <td>71824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>I</th>\n",
       "      <td>23994</td>\n",
       "      <td>23994</td>\n",
       "      <td>23994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>J</th>\n",
       "      <td>25241</td>\n",
       "      <td>25241</td>\n",
       "      <td>25241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>K</th>\n",
       "      <td>32762</td>\n",
       "      <td>32762</td>\n",
       "      <td>32762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>L</th>\n",
       "      <td>53391</td>\n",
       "      <td>53391</td>\n",
       "      <td>53391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>M</th>\n",
       "      <td>83971</td>\n",
       "      <td>83971</td>\n",
       "      <td>83971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>N</th>\n",
       "      <td>103472</td>\n",
       "      <td>103472</td>\n",
       "      <td>103472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>O</th>\n",
       "      <td>17593</td>\n",
       "      <td>17593</td>\n",
       "      <td>17593</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>P</th>\n",
       "      <td>52166</td>\n",
       "      <td>52166</td>\n",
       "      <td>52166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Q</th>\n",
       "      <td>19676</td>\n",
       "      <td>19676</td>\n",
       "      <td>19676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>R</th>\n",
       "      <td>32610</td>\n",
       "      <td>32610</td>\n",
       "      <td>32610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>S</th>\n",
       "      <td>24609</td>\n",
       "      <td>24609</td>\n",
       "      <td>24609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T</th>\n",
       "      <td>20878</td>\n",
       "      <td>20878</td>\n",
       "      <td>20878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>U</th>\n",
       "      <td>24740</td>\n",
       "      <td>24740</td>\n",
       "      <td>24740</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>V</th>\n",
       "      <td>39557</td>\n",
       "      <td>39557</td>\n",
       "      <td>39557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>W</th>\n",
       "      <td>13111</td>\n",
       "      <td>13111</td>\n",
       "      <td>13111</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       node_idx  paper_id  Split_ID\n",
       "Label                              \n",
       "A          2670      2670      2670\n",
       "B         65303     65303     65303\n",
       "C        111502    111502    111502\n",
       "D        104005    104005    104005\n",
       "E         45014     45014     45014\n",
       "F         32876     32876     32876\n",
       "G         43452     43452     43452\n",
       "H         71824     71824     71824\n",
       "I         23994     23994     23994\n",
       "J         25241     25241     25241\n",
       "K         32762     32762     32762\n",
       "L         53391     53391     53391\n",
       "M         83971     83971     83971\n",
       "N        103472    103472    103472\n",
       "O         17593     17593     17593\n",
       "P         52166     52166     52166\n",
       "Q         19676     19676     19676\n",
       "R         32610     32610     32610\n",
       "S         24609     24609     24609\n",
       "T         20878     20878     20878\n",
       "U         24740     24740     24740\n",
       "V         39557     39557     39557\n",
       "W         13111     13111     13111"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the label distribution first: count the rows under each Label value\n",
    "label_groups = nodes_df.groupby('Label')\n",
    "label_dist = label_groups.count()\n",
    "print(label_dist.shape)\n",
    "label_dist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 可以看到一共有23个标签，A类最少，C类最多，基本每类都有几万个。下面从0开始，重新编码标签\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>node_idx</th>\n",
       "      <th>paper_id</th>\n",
       "      <th>Label</th>\n",
       "      <th>Split_ID</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>bfdee5ab86ef5e68da974d48a138c28e</td>\n",
       "      <td>S</td>\n",
       "      <td>0</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>78f43b8b62f040347fec0be44e5f08bd</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>a971601a0286d2701aa5cde46e63a9fd</td>\n",
       "      <td>G</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>ac4b88a72146bae66cedfd1c13e1552d</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   node_idx                          paper_id Label  Split_ID  label\n",
       "0         0  bfdee5ab86ef5e68da974d48a138c28e     S         0     18\n",
       "1         1  78f43b8b62f040347fec0be44e5f08bd   NaN         0     -1\n",
       "2         2  a971601a0286d2701aa5cde46e63a9fd     G         0      6\n",
       "3         3  ac4b88a72146bae66cedfd1c13e1552d   NaN         0     -1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number the labels from 0 in the order of label_dist's index\n",
    "# (A-W -> 0-22 for this dataset).\n",
    "label_to_id = {name: i for i, name in enumerate(label_dist.index.to_list())}\n",
    "\n",
    "# Build the numeric column in one assignment: labels missing from the mapping\n",
    "# (i.e. NaN) come out of map() as NaN and are then set to -1.\n",
    "# This avoids `nodes_df.label.fillna(-1, inplace=True)`, an in-place call on a\n",
    "# column accessor that is deprecated and does not update the frame under pandas\n",
    "# Copy-on-Write, which would leave NaN behind and make the int cast fail.\n",
    "nodes_df['label'] = nodes_df.Label.map(label_to_id).fillna(-1).astype('int')\n",
    "nodes_df.head(4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 只保留新的node index、标签和原始的分割标签"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>node_idx</th>\n",
       "      <th>label</th>\n",
       "      <th>Split_ID</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3655448</th>\n",
       "      <td>3655448</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3655449</th>\n",
       "      <td>3655449</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3655450</th>\n",
       "      <td>3655450</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3655451</th>\n",
       "      <td>3655451</td>\n",
       "      <td>-1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         node_idx  label  Split_ID\n",
       "3655448   3655448     -1         1\n",
       "3655449   3655449     -1         1\n",
       "3655450   3655450     -1         1\n",
       "3655451   3655451     -1         1"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep just the node index, the numeric label and the original split flag\n",
    "kept_columns = ['node_idx', 'label', 'Split_ID']\n",
    "nodes = nodes_df[kept_columns]\n",
    "nodes.tail(4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 划分Train/Validation/Test\n",
    "\n",
    "由于只有原始的Train_nodes文件里面包括了标签，所以这里的Train/Validation是对原始Train_nodes中带标签节点的再分割。\n",
    "\n",
    "这里按照9:1的比例划分Train/Validation。Test就是原来的validation_nodes里面的index。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取所有的标签\n",
    "tr_val_labels_df = nodes[(nodes.Split_ID == 0) & (nodes.label >= 0)]\n",
    "test_label_df = nodes[nodes.Split_ID == 1]\n",
    "\n",
    "# 按照0~22每个标签划分train/validation\n",
    "tr_labels_idx = np.array([0])\n",
    "val_labels_idx = np.array([0])\n",
    "split_ratio = 0.9\n",
    "\n",
    "for label in range(23):\n",
    "    label_idx = tr_val_labels_df[tr_val_labels_df.label == label].node_idx.to_numpy()\n",
    "    split_point = int(label_idx.shape[0] * split_ratio)\n",
    "    \n",
    "    # 把每个标签的train和validation的index添加到整个列表\n",
    "    tr_labels_idx = np.append(tr_labels_idx, label_idx[: split_point])\n",
    "    val_labels_idx = np.append(val_labels_idx, label_idx[split_point: ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取Train/Validation/Test标签index\n",
    "tr_labels_idx = tr_labels_idx[1: ]\n",
    "val_labels_idx = val_labels_idx[1: ]\n",
    "\n",
    "test_labels_idx = test_label_df.node_idx.to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric label of every node as one NumPy array, in the row order of `nodes`\n",
    "labels = nodes['label'].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the labels and the train/validation/test indices as one pickle,\n",
    "# a binary format that is quick to read back at modelling time\n",
    "label_path = os.path.join(base_path, publish_path, 'labels.pkl')\n",
    "\n",
    "with open(label_path, 'wb') as out_file:\n",
    "    split_payload = {\n",
    "        'tr_label_idx': tr_labels_idx,\n",
    "        'val_label_idx': val_labels_idx,\n",
    "        'test_label_idx': test_labels_idx,\n",
    "        'label': labels,\n",
    "    }\n",
    "    pickle.dump(split_payload, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dgl]",
   "language": "python",
   "name": "conda-env-dgl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
